#### set fixed effects ####
# Fixed-effect variable sets (presumably column names used as fixed effects).
# The `*2` variants are alternative specifications; each `m_*` set equals the
# corresponding `eb_*` set with "event" appended.
eb_fv <- c("nutsyear", "season")
eb_fv2 <- c("unit", "nuts0year", "season")  # alternative
m_fv <- c(eb_fv, "event")
m_fv2 <- c(eb_fv2, "event") # alternative
# cluster
env_cv <- "nuts0year"

#### make formula for fixest ####
# Build a fixest two-part formula of the form "yvar ~ x1+x2 | fe1+fe2".
#   yvar   name of the dependent variable
#   xvars  regressor names, joined with "+" (default 0 -> literal term `0`)
#   fevars fixed-effect names, joined with "+" (default 0 -> literal term `0`)
# Returns a formula object.
f_feols <- function(yvar, xvars = 0, fevars = 0) {
  rhs <- paste(xvars, collapse = "+")
  fe_part <- paste(fevars, collapse = "+")
  as.formula(paste(yvar, "~", rhs, "|", fe_part))
}

#### make formula for lm ####
# Build an lm formula without intercept: "yvar ~ x1+x2 -1".
#   yvar  name of the dependent variable
#   xvars regressor names, joined with "+" (default 0 -> "yvar ~ 0 -1")
# Returns a formula object.
f_lm <- function(yvar, xvars = 0) {
  rhs_terms <- paste(xvars, collapse = "+")
  formula_text <- paste(yvar, "~", rhs_terms, "-1")
  as.formula(formula_text)
}

#### texreg default arguments ####
# Wrapper around texreg::texreg() with the project's default table options
# (booktabs/dcolumn LaTeX, 3 digits, stars at 0.01/0.05/0.1, groups shown,
# adjusted R2 and F statistic hidden).
#   l                  model(s) to tabulate, passed to texreg::texreg()
#   ...                further arguments forwarded to texreg::texreg()
#   custom.model.names column labels; by default "(1)", "(2)", ... with one
#                      label per model in `l` (previously these were taken
#                      from a global `tab_export`, which broke for any other
#                      list of models)
# Returns whatever texreg::texreg() returns.
my_texreg <- function(l, ..., custom.model.names = NULL) {
  if (is.null(custom.model.names)) {
    # NOTE(review): assumes texreg treats an `l` whose first class is not
    # "list" as a single model — confirm against texreg.
    n_models <- if (identical(class(l)[1], "list")) length(l) else 1L
    custom.model.names <- paste0("(", seq_len(n_models), ")")
  }
  texreg::texreg(
    l, ...,
    include.adjrs = FALSE, include.fstatistic = FALSE, include.groups = TRUE,
    digits = 3, table = FALSE, booktabs = TRUE, use.packages = FALSE,
    dcolumn = TRUE, stars = c(0.01, 0.05, 0.1),
    custom.model.names = custom.model.names
  )
}

# Wrapper around texreg::screenreg() (console table) with the project's
# default options: 3 digits, stars at 0.01/0.05/0.1, groups shown, adjusted
# R2 and F statistic hidden.
#   l                  model(s) to tabulate, passed to texreg::screenreg()
#   ...                further arguments forwarded to texreg::screenreg()
#   custom.model.names column labels; by default "(1)", "(2)", ... with one
#                      label per model in `l` rather than per element of a
#                      global object
# Returns whatever texreg::screenreg() returns.
my_screenreg <- function(l, ..., custom.model.names = NULL) {
  if (is.null(custom.model.names)) {
    # NOTE(review): assumes texreg treats an `l` whose first class is not
    # "list" as a single model — confirm against texreg.
    n_models <- if (identical(class(l)[1], "list")) length(l) else 1L
    custom.model.names <- paste0("(", seq_len(n_models), ")")
  }
  texreg::screenreg(
    l, ...,
    include.adjrs = FALSE, include.fstatistic = FALSE, include.groups = TRUE,
    digits = 3, table = FALSE, booktabs = TRUE, use.packages = FALSE,
    dcolumn = TRUE, stars = c(0.01, 0.05, 0.1),
    # custom.gof.names = c("Observations", "Region-year fixed effects (groups)", "Season fixed effects (groups)", 
    #                      "Event type fixed effects (groups)", "R² (overall)", "R² (within)"),
    # custom.coef.map	= temp_4w_map,
    custom.model.names = custom.model.names
  )
}

#### tidy mediate ####
# Tidy a mediation-analysis result into a tibble with one row per quantity:
# average mediation effect, average direct effect, total effect, proportion
# mediated.
#   x  presumably a mediation::mediate() result — it must provide d0, z0,
#      tau.coef, n.avg, the matching *.ci and *.p fields, mediator and treat
# Returns a tibble with columns term, estimate, ci_low, ci_high, p, mediator,
# treat, stars. `stars` is "***" for p < 0.01, "**" for p < 0.05, "*" for
# p < 0.1, and "" otherwise (including NA p-values).
tidy_med <- function(x){
  tibble(
    # term = c("$\\eta$: Average mediation effect", "$\\zeta$: Average direct effect", "$\\tau$: Total effect", "Proportion mediated"),
    term = c("Average mediation effect", "Average direct effect", "Total effect", "Proportion mediated"),
    estimate = c(x$d0, x$z0, x$tau.coef, x$n.avg),
    ci_low = c(x$d0.ci[1], x$z0.ci[1], x$tau.ci[1], x$n.avg.ci[1]),
    ci_high = c(x$d0.ci[2], x$z0.ci[2], x$tau.ci[2], x$n.avg.ci[2]),
    p = c(x$d0.p, x$z0.p, x$tau.p, x$n.avg.p),
    mediator = x$mediator,
    treat = x$treat
  ) %>%
    mutate(
      # Ordered strict thresholds leave no gaps between bands (the previous
      # between(p, lo, hi - 1e-5) ranges left e.g. p = 0.04999.. unstarred).
      stars = case_when(
        p < 0.01 ~ "***",
        p < 0.05 ~ "**",
        p < 0.1 ~ "*",
        TRUE ~ ""
      )
    )
}

# Extract sensitivity-analysis thresholds into a named numeric vector rounded
# to 4 decimals, labelled with LaTeX strings for table output.
#   x  presumably a mediation::medsens() result — it must provide err.cr.d,
#      R2star.d.thresh and R2tilde.d.thresh
# The labels are ordinary double-quoted strings with an escaped backslash:
# escape sequences are also processed inside backtick names, so a backtick
# name `$\rho$` would contain a carriage return instead of "\rho".
tidy_medsens <- function(x) {
  sens <- c(
    "$\\rho$" = x$err.cr.d,
    "$R^2_{M*}R^2_{Y*}$" = x$R2star.d.thresh,
    "$R^2_{M~}R^2_{Y~}$" = x$R2tilde.d.thresh
  )
  round(sens, 4)
}

#### functions for distributed lag models ####
# Collect cumulative-effect estimates from test_cum_effect() for every
# (independent variable, dependent variable) pair into one data frame. Despite
# the name this returns data, not a plot; see plot_lc().
#   x          fitted model (fixest or fixest_multi) passed to test_cum_effect()
#   dep_vars   named vector of dependent variables; names become
#              `dep_var_name` (defaults to the global `emm_dep_vars`)
#   ind_vars   named vector of independent variables; names become
#              `ind_var_name` (defaults to the global `emm_ind_vars`)
#   conf_level confidence level forwarded as `conf`
#   lag_start  lag assigned to the first horizon; defaults to min() of the
#              global `lags`, which was previously read implicitly
# Returns a data frame with one block of rows per pair, ordered by ind_vars
# then dep_vars; `dep_var_name` and `ind_var_name` are factors whose levels
# follow the order of the input names.
make_lc_plot <- function(x, dep_vars=emm_dep_vars, ind_vars=emm_ind_vars, conf_level=0.99,
                         lag_start = min(lags)) {
  pieces <- vector("list", length(ind_vars) * length(dep_vars))
  k <- 0L
  for (i in seq_along(ind_vars)) {
    for (j in seq_along(dep_vars)) {
      k <- k + 1L
      pieces[[k]] <-
        test_cum_effect(
          model = x, dep_var = dep_vars[j], ind_var = ind_vars[i],
          lag_start = lag_start, conf = conf_level
        ) %>%
        mutate(
          dep_var_name = names(dep_vars[j]),
          ind_var_name = names(ind_vars[i])
        )
    }
  }
  # Leading empty tibble keeps the combined result a plain tibble.
  bind_rows(tibble(), pieces) %>%
    mutate(
      dep_var_name = factor(dep_var_name, levels = unique(names(dep_vars))),
      ind_var_name = factor(ind_var_name, levels = unique(names(ind_vars)))
    )
}
# Draw cumulative lag effects (e.g. the output of make_lc_plot()) as a grid
# of ribbon/line panels and save it to `file`.
#   x           data frame with columns lag, estimate, conf.low, conf.high,
#               p.value, ind_var_name (facet rows), dep_var_name (facet cols)
#   file        output path passed to ggsave()
#   y_on        text appended to the y-axis label
#   p_threshold estimates with p.value below this get a point marker
#               (default 0.01)
# The saved height is 1.5 per distinct ind_var_name. Returns the value of
# ggsave().
plot_lc <- function(x, file, y_on="climate news", p_threshold = 0.01) {
  plot_height <- length(unique(x$ind_var_name)) * 1.5
  lc_gg <- ggplot(x, aes(x = lag, y = estimate)) +
    geom_hline(yintercept = 0, alpha = 0.5) +
    geom_vline(xintercept = 0, alpha = 0.5) +
    geom_ribbon(aes(ymin = conf.low, ymax = conf.high), alpha = 0.2) +
    geom_line(linewidth = 1) +
    geom_point(data = filter(x, p.value < p_threshold), size = 2) +
    scale_x_continuous(breaks = -10:10, minor_breaks = NULL) +
    facet_grid(rows = vars(ind_var_name), cols = vars(dep_var_name), scales = "free_y") +
    labs(
      # title = title,
      y = paste0("Standardized cumulative effect on ", y_on),
      x = "Weeks since event"
    )
  # Save this plot explicitly rather than relying on ggsave()'s implicit
  # last_plot() default.
  ggsave(file, plot = lc_gg, height = plot_height)
}

# Test cumulative effects of a distributed-lag regressor with hypotheses().
# For each horizon i = 0..n_lags, tests H0: the sum of the coefficients from
# `f(<ind_var>, 2)` through the i-th coefficient after it equals 0. When
# `ind_var2` is given, its matching coefficients are added to the same sum.
# NOTE(review): assumes each variable's lead/lag coefficients are contiguous
# in coef() order — confirm against the model specification.
#   model     a fixest model, or a fixest_multi from which the model whose
#             name matches `dep_var` (regex) is taken
#   dep_var   pattern selecting the model within a fixest_multi
#   ind_var   variable whose coefficients are accumulated
#   ind_var2  optional second variable (NA to skip)
#   lag_start lag value assigned to horizon 0 (row `lag = i + lag_start`)
#   conf      confidence level passed to hypotheses()
#   n_lags    number of additional horizons (default 10, giving 11 rows)
# Returns a tibble with one row per horizon.
test_cum_effect <- function(model, dep_var, ind_var, ind_var2=NA, lag_start, conf=0.95,
                            n_lags = 10) {
  if (inherits(model, "fixest_multi")) {
    i_model <- which(grepl(dep_var, names(model)))
    if (length(i_model) != 1L) {
      stop("Expected exactly one model matching '", dep_var, "', found ",
           length(i_model), ".", call. = FALSE)
    }
    model_col <- model[[i_model]]
  } else if (inherits(model, "fixest")) {
    model_col <- model
  } else {
    stop("Model should be of class fixest or fixest_multi.")
  }
  coef_names <- names(coef(model_col))

  # Index of the `f(<var>, 2)` coefficient, validated so that the next
  # n_lags coefficients exist. Uses the first match if several names contain
  # the pattern.
  first_pos <- function(var) {
    pos <- which(grepl(paste0("f(", var, ", 2)"), coef_names, fixed = TRUE))
    if (length(pos) == 0L) {
      stop("No coefficient 'f(", var, ", 2)' found in the model.", call. = FALSE)
    }
    if (pos[1] + n_lags > length(coef_names)) {
      stop("Not enough coefficients after 'f(", var, ", 2)' for ", n_lags,
           " lags.", call. = FALSE)
    }
    pos[1]
  }

  starts <- first_pos(ind_var)
  if (!is.na(ind_var2)) {
    starts <- c(starts, first_pos(ind_var2))
  }

  lc_list <- lapply(0:n_lags, function(i) {
    terms <- unlist(lapply(starts, function(s) coef_names[s:(s + i)]))
    hypotheses(
      model_col,
      paste0("`", paste(terms, collapse = "` + `"), "` = 0"),
      conf_level = conf
    ) %>%
      mutate(lag = i + lag_start)
  })
  # Leading empty tibble keeps the combined result a plain tibble.
  bind_rows(tibble(), lc_list)
}


# Plot a cumulative effect over lags as a line with a shaded confidence band.
#   df     data frame with columns lag, estimate, conf.low, conf.high
#   title  optional plot title
# Returns the ggplot object.
# NOTE(review): a second function named plot_dl is defined later in this file
# and replaces this one once the file is sourced; consider renaming one.
plot_dl <- function(df, title=NULL) {
  base_plot <- ggplot(df, aes(x = lag, y = estimate))
  base_plot +
    geom_ribbon(aes(ymin = conf.low, ymax = conf.high), alpha = 0.3, fill = "cadetblue") +
    geom_line(linewidth = 1) +
    geom_hline(yintercept = 0) +
    geom_vline(xintercept = 0) +
    scale_x_continuous(breaks = -10:10, minor_breaks = NULL) +
    labs(title = title, y = "Cumulative effect", x = "Weeks since events")
}

#### tidy distributed lags ####
# Tidy a collection of fitted models into one coefficient table with label
# columns for tables and plots.
#   x  list-like of fitted models; each element must support tidy() and
#      provide `$lhs` (presumably fixest models — TODO confirm)
# Terms containing "event_name" are dropped. Added columns:
#   ci_low, ci_high  estimate -/+ 1.96 * std.error
#   Publication      factor from `lhs`: "reg" -> regional, "nat" -> national
#   Type             "Cold spell" if term matches the cold pattern, else
#                    "Warm spell"
#   type_tab         cold / warm / "Temp. anomaly" classification of term
#   lag              -2, -1, 1, 2 from term suffixes, otherwise 0
#   var              factor naming the variable family of the term
#   threshold        "5%", "2.5%" or "10%" from the quantile in the term
#   var_treshold     "<var> (<threshold>)" (column name spelled as-is)
tidy_dl <- function(x) {
  cold_pat <- "(q10)|(q05)|(q025)|(neg)"
  warm_pat <- "(q90)|(q95)|(q975)|(pos)"
  pub_levels <- c("Regional climate news", "National climate news")
  var_levels <- c("Temperature anomaly", "Pos. temperature anomaly", "Neg. temperature anomaly", "Warm spell severity", "Warm spell duration")

  per_model <- lapply(x, function(fit) cbind(lhs = fit$lhs, tidy(fit)))
  coefs <- filter(bind_rows(per_model), !grepl("event_name", term))

  mutate(
    coefs,
    ci_low = estimate - 1.96 * std.error,
    ci_high = estimate + 1.96 * std.error,
    # Topic = case_when(
    #   grepl("climate_w", lhs) ~ "Topic: Climate change",
    #   grepl("climate_ext_w", lhs) ~ "Topic: Climate change (ext)",
    #   grepl("env", lhs) ~ "Topic: Environment & Ecology"
    # ),
    Publication = factor(
      case_when(
        grepl("reg", lhs) ~ "Regional climate news",
        grepl("nat", lhs) ~ "National climate news"
      ),
      levels = pub_levels
    ),
    Type = if_else(grepl(cold_pat, term), "Cold spell", "Warm spell"), # for plot
    type_tab = case_when( # for plot
      grepl(cold_pat, term) ~ "Cold spell",
      grepl(warm_pat, term) ~ "Warm spell",
      .default = "Temp. anomaly"
    ),
    # NOTE(review): "_L1" is not anchored with "$" unlike the other lag
    # patterns, so it also matches e.g. "_L10" — confirm this is intended.
    lag = case_when(
      grepl("_L2$", term) ~ -2,
      grepl("_L1", term) ~ -1,
      grepl("_l1$", term) ~ 1,
      grepl("_l2$", term) ~ 2,
      .default = 0
    ),
    var = factor(
      case_when(
        grepl("hs", term) ~ "Warm spell severity",
        grepl("xf", term) ~ "Warm spell duration",
        grepl("nf_pos", term) ~ "Pos. temperature anomaly",
        grepl("nf_neg", term) ~ "Neg. temperature anomaly",
        grepl("nf", term) ~ "Temperature anomaly"
      ),
      levels = var_levels
    ),
    threshold = case_when(
      grepl("(05)|(95)", term) ~ "5%",
      grepl("(025)|(975)", term) ~ "2.5%",
      grepl("(10)|(90)", term) ~ "10%"
    ),
    var_treshold = paste0(var, " (", threshold, ")")
  )
}

#### plot distributed lags ####
# Point-range plot of warm-spell lag coefficients (e.g. the output of
# tidy_dl()), one facet per variable, dodged by Publication.
#   x  data frame with columns Type, var, lag, estimate, ci_low, ci_high,
#      Publication
# Rows are kept only when Type == "Warm spell" and var is not NA.
# Returns the ggplot object.
# NOTE(review): this definition replaces an earlier plot_dl(df, title) in the
# same file; scale_colour_brewer() has no colour aesthetic to act on.
plot_dl <- function(x) {
  warm_rows <- filter(
    x,
    Type == "Warm spell",
    !is.na(var)
    # Topic == "Topic: Climate change"
  )
  ggplot(warm_rows, aes(x = lag, y = estimate)) +
    geom_pointrange(
      aes(ymin = ci_low, ymax = ci_high, shape = Publication), # , colour = Type
      position = position_dodge(width = 0.75)
    ) +
    geom_hline(yintercept = 0, colour = "grey60") +
    labs(y = "Standardized effect on news coverage", x = "Lag (4-week periods)") +
    scale_colour_brewer(palette = "Set1", direction = -1) +
    facet_wrap(vars(var), nrow = 1) +
    # facet_grid(cols = vars(Publication), rows = vars(Topic)) +
    theme(legend.position = "bottom", legend.title = element_blank())
}



